{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train DFR on Stylized ImageNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import numpy as np\n",
    "import torch\n",
    "import torchvision\n",
    "import einops\n",
    "import json\n",
    "import tqdm\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import sys\n",
    "\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_embeddings(path):\n",
    "    arr = np.load(path)\n",
    "    x, y = arr[\"embeddings\"], arr[\"labels\"]\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_logreg(\n",
    "    x_train, y_train, eval_datasets,\n",
    "    n_epochs=1000, weight_decay=0., lr=1.,\n",
    "    batch_size=1000, verbose=0, \n",
    "    n_classes=1000\n",
    "    ):\n",
    "    \n",
    "    x_train = torch.from_numpy(x_train).float()\n",
    "    y_train = torch.from_numpy(y_train).long()\n",
    "    train_ds = torch.utils.data.TensorDataset(x_train, y_train)\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        train_ds, shuffle=True, batch_size=batch_size)\n",
    "    \n",
    "    d = x_train.shape[1]\n",
    "    model = torch.nn.Linear(d, n_classes).cuda()\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    optimizer = torch.optim.SGD(\n",
    "        model.parameters(), weight_decay=weight_decay, lr=lr)\n",
    "    schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_epochs)\n",
    "    \n",
    "    for epoch in range(n_epochs):\n",
    "        correct, total = 0, 0\n",
    "        for x, y in train_loader:\n",
    "            x, y = x.cuda(), y.cuda()\n",
    "            optimizer.zero_grad()\n",
    "            pred = model(x)\n",
    "            loss = criterion(pred, y)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            schedule.step()\n",
    "            correct += (torch.argmax(pred, -1) == y).detach().float().sum().item()\n",
    "            total += len(y)\n",
    "        if verbose > 1 and ((n_epochs < 10) or epoch % (n_epochs // 10) == 0):\n",
    "            print(epoch, correct / total)\n",
    "    \n",
    "    results = {}\n",
    "    for key, (x_test, y_test) in eval_datasets.items():\n",
    "        x_test = torch.from_numpy(x_test).float().cuda()\n",
    "        pred = torch.argmax(model(x_test), axis=-1).detach().cpu().numpy()\n",
    "        results[key] = (pred == y_test).mean()\n",
    "    \n",
    "    return model, results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_data(\n",
    "    train_datasets, eval_datasets, \n",
    "    num_stylized=-1, num_original=0, preprocess=True):\n",
    "    \n",
    "    x_train, y_train = train_datasets[\"imagenet\"]\n",
    "    idx = np.arange(len(x_train))\n",
    "    np.random.shuffle(idx)\n",
    "    idx = idx[:num_original]\n",
    "    x_train = x_train[idx]\n",
    "    y_train = y_train[idx]\n",
    "\n",
    "    x_train_mr, y_train_mr = train_datasets[\"imagenet_stylized\"]\n",
    "    idx = np.arange(len(x_train_mr))\n",
    "    np.random.shuffle(idx)\n",
    "    idx = idx[:num_stylized]\n",
    "    x_train_mr = x_train_mr[idx]\n",
    "    y_train_mr = y_train_mr[idx]\n",
    "\n",
    "    x_train = np.concatenate([x_train, x_train_mr])\n",
    "    y_train = np.concatenate([y_train, y_train_mr])\n",
    "\n",
    "    if preprocess:\n",
    "        mean = x_train.mean(axis=0)[None, :]\n",
    "        std = x_train.std(axis=0)[None, :]\n",
    "        x_train = (x_train - mean) / std\n",
    "        eval_datasets_preprocessed = {\n",
    "            k: ((x - mean) / std, y)\n",
    "            for k, (x, y) in eval_datasets.items()\n",
    "        }\n",
    "    else:\n",
    "        eval_datasets_preprocessed = eval_datasets\n",
    "        mean, std = None, None\n",
    "    return x_train, y_train, eval_datasets_preprocessed, mean, std\n",
    "\n",
    "\n",
    "def run_experiment(\n",
    "    train_datasets, eval_datasets,\n",
    "    num_stylized=-1, num_original=0, preprocess=True,\n",
    "    n_epochs=10, weight_decay=0., lr=1., batch_size=1000,\n",
    "    verbose=0, num_seeds=3\n",
    "):\n",
    "    models, results = {}, {}\n",
    "    for seed in range(num_seeds):\n",
    "        x_train, y_train, eval_datasets_preprocessed, _, _ = get_data(\n",
    "            train_datasets, eval_datasets,\n",
    "            num_stylized, num_original, preprocess)\n",
    "        model, results_seed = train_logreg(\n",
    "            x_train, y_train, eval_datasets_preprocessed,\n",
    "            n_epochs, weight_decay, lr, batch_size, verbose)\n",
    "        results[seed] = results_seed\n",
    "        models[seed] = model\n",
    "    if num_seeds > 1:\n",
    "        results_aggrgated = {\n",
    "            key: (np.mean([results[seed][key] for seed in results.keys()]),\n",
    "                  np.std([results[seed][key] for seed in results.keys()]))\n",
    "            for key in results[0].keys()\n",
    "        }\n",
    "    else:\n",
    "        results_aggrgated = results[0]\n",
    "    return results, results_aggrgated, models\n",
    "\n",
    "\n",
    "def print_results(results_dict):\n",
    "    print(\"-------------------\")\n",
    "    for key, val in results_dict.items():\n",
    "        print(\"{}: {:.3f}±{:.3f}\".format(key, val[0], val[1]))\n",
    "    print(\"-------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Change data paths here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "imagenet_c_corruptions = [\"brightness\", \"defocus_blur\", \"fog\", \"gaussian_blur\", \"glass_blur\",\n",
    "                          \"jpeg_compression\", \"pixelate\", \"shot_noise\", \"spatter\", \"zoom_blur\",\n",
    "                          \"contrast\", \"elastic_transform\", \"frost\", \"gaussian_noise\",\n",
    "                          \"impulse_noise\", \"motion_blur\", \"saturate\", \"snow\", \"speckle_noise\"]\n",
    "intensities = [3]\n",
    "\n",
    "eval_path_dict = {\n",
    "    \"imagenet_r\": \"/datasets/imagenet-r/imagenet-r_resnet50_val_embeddings.npz\",\n",
    "    \"imagenet_a\": \"/datasets/imagenet-a/imagenet-a_resnet50_val_embeddings.npz\",\n",
    "    \"imagenet\": \"/datasets/imagenet_symlink/resnet50_val_embeddings.npz\",\n",
    "    \"imagenet_stylized\": \"/datasets/imagenet-stylized/imagenet_resnet50_val_embeddings.npz\",\n",
    "}\n",
    "eval_datasets = {k: load_embeddings(p) for k, p in eval_path_dict.items()}\n",
    "\n",
    "train_path_dict = {\n",
    "    \"imagenet\": \"/datasets/imagenet_symlink/resnet50_train_embeddings.npz\",\n",
    "    \"imagenet_stylized\": \"/datasets/imagenet-stylized/imagenet_resnet50_train_embeddings.npz\",\n",
    "}\n",
    "train_datasets = {k: load_embeddings(p) for k, p in train_path_dict.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_w_b(model):\n",
    "    w = model.weight.detach().cpu().numpy()\n",
    "    b = model.bias.detach().cpu().numpy()\n",
    "    return w, b"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## IN+SIN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, _, _, mean, std = get_data(\n",
    "    train_datasets, eval_datasets, \n",
    "    num_stylized=-1, num_original=-1, preprocess=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.46593634830407893\n",
      "10 0.5358872353051408\n",
      "20 0.5469958304703604\n",
      "30 0.5548984165157099\n",
      "40 0.5609985008432756\n",
      "50 0.5658953557373977\n",
      "60 0.5694905990380411\n",
      "70 0.5725611374851646\n",
      "80 0.5757015584983447\n",
      "90 0.5789555874820413\n",
      "{'imagenet_r': 0.27166666666666667, 'imagenet_a': 0.0024, 'imagenet': 0.74524, 'imagenet_stylized': 0.21418}\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 100\n",
    "_, combo_results, models = run_experiment(train_datasets, eval_datasets,\n",
    "               num_stylized=-1, num_original=-1, num_seeds=1,\n",
    "               n_epochs=n_epochs, weight_decay=0., verbose=2, batch_size=10000)\n",
    "w, b = get_w_b(models[0])\n",
    "np.savez(f\"dfr_insin_{n_epochs}_weights.npz\",\n",
    "         w=w, b=b)\n",
    "print(combo_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "arr = np.load(\"dfr_insin_100_weights.npz\")\n",
    "w, b = arr[\"w\"], arr[\"b\"]\n",
    "np.savez(\"dfr_insin_weights_bs10k.npz\",\n",
    "        w=w, b=b, preprocess_mean=mean, preprocess_std=std)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Original"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, _, _, mean, std = get_data(\n",
    "    train_datasets, eval_datasets, \n",
    "    num_stylized=0, num_original=-1, preprocess=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.7988451808357798\n",
      "10 0.8812745955400088\n",
      "20 0.8894207945530639\n",
      "30 0.8974975014054595\n",
      "40 0.8996704978449622\n",
      "50 0.9052111312386782\n",
      "60 0.9066251795864826\n",
      "70 0.9108571740895746\n",
      "80 0.9123571116247111\n",
      "90 0.9148221313011431\n",
      "{'imagenet_r': 0.2287, 'imagenet_a': 0.0004, 'imagenet': 0.75224, 'imagenet_stylized': 0.06264}\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 100\n",
    "_, original_results, models = run_experiment(train_datasets, eval_datasets,\n",
    "               num_stylized=0, num_original=-1, num_seeds=1,\n",
    "               n_epochs=n_epochs, weight_decay=0., verbose=2, batch_size=10000)\n",
    "w, b = get_w_b(models[0])\n",
    "np.savez(f\"dfr_in_{n_epochs}_weights.npz\",\n",
    "         w=w, b=b)\n",
    "print(original_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "arr = np.load(\"dfr_in_100_weights.npz\")\n",
    "w, b = arr[\"w\"], arr[\"b\"]\n",
    "np.savez(\"dfr_in_weights_bs10k.npz\",\n",
    "        w=w, b=b, preprocess_mean=mean, preprocess_std=std)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SIN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, _, _, mean, std = get_data(\n",
    "    train_datasets, eval_datasets, \n",
    "    num_stylized=-1, num_original=0, preprocess=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.12500156162158785\n",
      "10 0.23137375851083766\n",
      "20 0.24725857330251733\n",
      "30 0.26579580236117184\n",
      "40 0.2706055968517709\n",
      "50 0.2846703416828034\n",
      "60 0.28752186270223\n",
      "70 0.2968400587169717\n",
      "80 0.3009861640327316\n",
      "90 0.30561246798675745\n",
      "{'imagenet_r': 0.24566666666666667, 'imagenet_a': 0.004, 'imagenet': 0.65076, 'imagenet_stylized': 0.21952}\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 100\n",
    "_, stylized_results, models = run_experiment(train_datasets, eval_datasets,\n",
    "               num_stylized=-1, num_original=0, num_seeds=1,\n",
    "               n_epochs=n_epochs, weight_decay=0., verbose=2, batch_size=10000)\n",
    "w, b = get_w_b(models[0])\n",
    "np.savez(f\"dfr_sin_{n_epochs}_weights.npz\",\n",
    "         w=w, b=b)\n",
    "print(stylized_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "arr = np.load(\"dfr_sin_100_weights.npz\")\n",
    "w, b = arr[\"w\"], arr[\"b\"]\n",
    "np.savez(\"dfr_sin_weights_bs10k.npz\",\n",
    "        w=w, b=b, preprocess_mean=mean, preprocess_std=std)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py38",
   "language": "python",
   "name": "py38"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
